import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def lambda_decay(steps, total_steps, *, min_factor=0.0):
    """Return a linearly decaying multiplicative factor.

    The factor is ``(total_steps - steps) / total_steps``: 1.0 at ``steps == 0``
    and 0.0 at ``steps == total_steps``. It is suitable as a multiplier for
    schedulers such as ``torch.optim.lr_scheduler.LambdaLR``.

    Args:
        steps: Number of steps taken so far.
        total_steps: Step count at which the factor reaches zero; must be > 0.
        min_factor: Lower bound for the returned factor (keyword-only).
            Defaults to 0.0, so results inside ``[0, total_steps]`` are
            unchanged.

    Returns:
        The decay factor as a float, never below ``min_factor``.

    Raises:
        ValueError: If ``total_steps`` is not positive.
    """
    if total_steps <= 0:
        raise ValueError(f"total_steps must be positive, got {total_steps!r}")
    factor = (total_steps - steps) / total_steps
    # Without the floor, steps past total_steps would yield a negative factor,
    # e.g. a negative learning rate when used as an LR multiplier.
    return max(min_factor, factor)